# src/scripts/inference.py
import os
import sys
import torch
import random
import argparse
import yaml

# --- Import project modules (requires running as a module from the project root; sys.path is not modified) ---
try:
    # This assumes the script is run as a module from the project root
    from src.mad_datasets import MADTokenDataset
    from src.models import MADModel
except ImportError:
    print("Error: Make sure to run this script as a module from the project root.")
    print("Example: python -m src.scripts.inference")
    sys.exit(1)

def run_inference(config, model_weights_path, data_dir, sample_idx=None):
    """
    Load a fine-tuned MADModel and run inference on a single dataset sample.

    Args:
        config: Configuration mapping (as saved during fine-tuning). Entries
            whose keys appear in ``MADModel.get_param_keys()`` are forwarded to
            the model constructor; ``config['num_classes']`` is passed to the
            dataset.
        model_weights_path: Path to the checkpoint file. It may either be a
            mapping holding the weights under ``'state_dict'`` or a bare
            state dict.
        data_dir: Root directory handed to ``MADTokenDataset``.
        sample_idx: Optional index of the sample to evaluate. When ``None``
            (the default) a random index is chosen.

    Returns:
        On success, a dict with keys ``'index'``, ``'true_label'``,
        ``'predicted_class'`` and ``'confidence'``. ``None`` if the weights
        file is missing, the dataset is empty, or ``sample_idx`` is out of
        range; an error message is printed in those cases.
    """
    print("--- Starting Inference ---")

    # Fail fast, before building the model (and possibly allocating GPU
    # memory), when the checkpoint does not exist.
    if not os.path.exists(model_weights_path):
        print(f"Error: Model weights file not found at '{model_weights_path}'")
        return None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- Initialize model and load weights ---
    print(f"Loading model weights from: {model_weights_path}")

    # Forward only the config entries that MADModel's constructor accepts.
    # The key list is fetched once instead of once per config item.
    param_keys = MADModel.get_param_keys()
    model_params = {k: v for k, v in config.items() if k in param_keys}
    model = MADModel(**model_params).to(device)

    # NOTE(security): torch.load unpickles arbitrary Python objects; only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(model_weights_path, map_location=device)

    # Training checkpoints wrap the weights under 'state_dict'; otherwise the
    # file is treated as a bare state dict. The isinstance guard avoids a
    # TypeError from `in` when the file holds a non-mapping object.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    model.eval()
    print("Model loaded successfully.")

    # --- Load a single sample from the dataset ---
    if sample_idx is None:
        print(f"Loading a random sample from: {data_dir}")
    else:
        print(f"Loading sample {sample_idx} from: {data_dir}")
    dataset = MADTokenDataset(root_dir=data_dir, usage="all", num_classes=config['num_classes'])

    num_samples = len(dataset)
    if num_samples == 0:
        print(f"Error: No data found in '{data_dir}'. Please ensure the path is correct and the directory is not empty.")
        return None

    if sample_idx is None:
        sample_idx = random.randrange(num_samples)
    elif not 0 <= sample_idx < num_samples:
        print(f"Error: sample_idx {sample_idx} is out of range for a dataset of {num_samples} samples.")
        return None

    class_names = dataset.get_class_names()
    (x_img, x_sig), y_true = dataset[sample_idx]

    # int() accepts both plain Python ints and one-element tensors, whereas
    # .item() only works on tensors.
    true_label = class_names[int(y_true)]
    print(f"Selected sample index: {sample_idx}")
    print(f"True Label: {true_label}")

    # --- Run Inference ---
    with torch.no_grad():
        # Prepend a batch dimension of size 1 and match the model's dtype.
        x_img = x_img.unsqueeze(0).to(device).float()
        x_sig = x_sig.unsqueeze(0).to(device).float()

        # Presumably logits have shape (1, num_classes) -- softmax over dim=1.
        logits = model(x_img, x_sig, for_supcon=False)
        probs = torch.softmax(logits, dim=1)
        prediction_idx = torch.argmax(probs, dim=1).item()

        predicted_class = class_names[prediction_idx]
        confidence = probs[0, prediction_idx].item()

    # --- Print Results ---
    print("\n--- Inference Result ---")
    print(f"Predicted Class: {predicted_class}")
    print(f"Confidence: {confidence:.2%}")
    print("------------------------")

    return {
        "index": sample_idx,
        "true_label": true_label,
        "predicted_class": predicted_class,
        "confidence": confidence,
    }


def main(argv=None):
    """
    Parse command-line arguments, load the fine-tuning config and run inference.

    Args:
        argv: Optional list of argument strings. When ``None``, argparse reads
            ``sys.argv[1:]``.

    Exits with status 1 if the config file is missing, cannot be read, is not
    valid YAML, or does not contain a mapping at the top level.
    """
    parser = argparse.ArgumentParser(description="Run inference with a fine-tuned MAD model.")

    # --- Define Command-Line Arguments ---
    parser.add_argument('--weights_path', type=str,
                        default="results/final_model/model_best_finetuned.pth.tar",
                        help="Path to the fine-tuned model weights (.pth.tar file).")

    parser.add_argument('--data_dir', type=str,
                        default="data/processed",
                        help="Path to the processed data directory.")

    parser.add_argument('--config_path', type=str,
                        default="results/final_model/config_finetune_used.yaml",
                        help="Path to the config file used for fine-tuning.")

    args = parser.parse_args(argv)

    # --- Load Configuration File ---
    # The config file saved during fine-tuning is required to match the model architecture.
    if not os.path.exists(args.config_path):
        print(f"Error: Config file not found at '{args.config_path}'.")
        print("This script requires the config file to correctly initialize the model architecture.")
        sys.exit(1)

    try:
        with open(args.config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
    except (OSError, yaml.YAMLError) as e:
        print(f"Error loading YAML config file: {e}. Please check the file format.")
        sys.exit(1)

    # safe_load returns None for an empty file (and can return a list or a
    # scalar); run_inference needs a mapping, so reject anything else here
    # instead of crashing later on config.items().
    if not isinstance(config, dict):
        print(f"Error: Config file '{args.config_path}' must contain a YAML mapping of settings.")
        sys.exit(1)
    print(f"Loaded configuration from {args.config_path}")

    run_inference(config, args.weights_path, args.data_dir)


if __name__ == "__main__":
    main()